{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Heart Rate Estimation using TAO HeartRateNet\n",
    "\n",
    "Transfer learning is the process of transferring learned features from one application to another. It is a commonly used training technique where you use a model trained on one task and re-train to use it on a different task. \n",
    "\n",
    "Train Adapt Optimize (TAO) Toolkit is a simple and easy-to-use Python based AI toolkit for taking purpose-built AI models and customizing them with users' own data.\n",
    "\n",
    "<img align=\"center\" src=\"https://d29g4g2dyqv443.cloudfront.net/sites/default/files/akamai/TAO/tlt-tao-toolkit-bring-your-own-model-diagram.png\" width=\"1080\"> "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Learning Objectives\n",
    "In this notebook, you will learn how to leverage the simplicity and convenience of TAO to:\n",
    "\n",
    "* Take a pretrained model and train a model on the COHFACE dataset\n",
    "* Run Inference on the trained model\n",
    "* Export the retrained model to a .etlt file for deployment to DeepStream SDK\n",
    "\n",
    "At the end of this notebook, you will have generated a trained and optimized `heartratenet` model, \n",
    "which you may deploy via [DeepStream](https://developer.nvidia.com/deepstream-sdk).\n",
    "\n",
    "### Table of Contents\n",
    "\n",
    "This notebook shows an example of non-invasive heart rate estimation using the Train Adapt Optimize (TAO) Toolkit.\n",
    "\n",
    "0. [Set up env variables, map drives, and install dependencies](#head-0)\n",
    "1. [Install the TAO launcher](#head-1)\n",
    "2. [Prepare dataset and pre-trained model](#head-2) <br>\n",
    "    2.1 [Verify downloaded dataset](#head-2-1) <br>\n",
    "    2.2 [Extract video data to image data](#head-2-2) <br>\n",
    "    2.2 [Process the extracted data](#head-2-3) <br>\n",
    "    2.2 [Download pre-trained model](#head-2-4) <br>\n",
    "3. [Generate tfrecords from RGB videos](#head-3) <br>\n",
    "    3.1 [Download haarcascade classifier](#head-3-1) <br>\n",
    "    3.2 [Generate tfrecords](#head-3-2) <br>\n",
    "4. [Provide training specification](#head-4) <br>\n",
    "5. [Run TAO training](#head-5) <br>\n",
    "6. [Evaluate the trained model](#head-6) <br>\n",
    "7. [Inference](#head-7) <br>\n",
    "8. [Export](#head-8) <br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Set up env variables, map drives and install dependencies <a class=\"anchor\" id=\"head-0\"></a>\n",
    "\n",
    "When using the purpose-built pretrained models from NGC, please make sure to set the `$KEY` environment variable to the key as mentioned in the model overview. Failing to do so, can lead to errors when trying to load them as pretrained models.\n",
    "\n",
    "The following notebook requires the user to set an env variable called the `$LOCAL_PROJECT_DIR` as the path to the users' workspace. Please note that the dataset to run this notebook is expected to reside in the `$LOCAL_PROJECT_DIR/heartratenet/data`, while the TAO experiment generated collaterals will be output to `$LOCAL_PROJECT_DIR/heartratenet`. More information on how to set up the dataset and the supported steps in the TAO workflow are provided in the subsequent cells.\n",
    "\n",
    "*Note: This notebook currently is by default set up to run training using 1 GPU. To use more GPU's please update the env variable `$NUM_GPUS` accordingly*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setting up env variables for cleaner command-line commands.\n",
    "import os\n",
    "\n",
    "%env KEY=nvidia_tlt\n",
    "%env NUM_GPUS=1\n",
    "%env USER_EXPERIMENT_DIR=/workspace/tao-experiments/heartratenet\n",
    "%env HEARTRATENET_DATA=/workspace/tao-experiments/heartratenet/data\n",
    "\n",
    "# Set this path if you don't run the notebook from the samples directory.\n",
    "# %env NOTEBOOK_ROOT=~/tao-samples/heartratenet\n",
    "\n",
    "# Please define this local project directory that needs to be mapped to the TAO docker session.\n",
    "# The dataset is expected to be present in $LOCAL_PROJECT_DIR/heartratenet/data, while the results for the steps\n",
    "# in this notebook will be stored at $LOCAL_PROJECT_DIR/heartratenet\n",
    "# !PLEASE MAKE SURE TO UPDATE THIS PATH!.\n",
    "%env LOCAL_PROJECT_DIR=/path/to/local/experiments\n",
    "\n",
    "# $PROJECT_DIR is the path to the sample notebook folder and the dependency folder\n",
    "# $PROJECT_DIR/deps should exist for dependency installation\n",
    "%env PROJECT_DIR=/path/to/local/samples_dir\n",
    "\n",
    "os.environ[\"LOCAL_DATA_DIR\"] = os.path.join(\n",
    "    os.getenv(\"LOCAL_PROJECT_DIR\", os.getcwd()),\n",
    "    \"heartratenet/data\"\n",
    ")\n",
    "os.environ[\"LOCAL_EXPERIMENT_DIR\"] = os.path.join(\n",
    "    os.getenv(\"LOCAL_PROJECT_DIR\", os.getcwd()),\n",
    "    \"heartratenet\"\n",
    ")\n",
    "\n",
    "# The sample spec files are present in the same path as the downloaded samples.\n",
    "os.environ[\"LOCAL_SPECS_DIR\"] = os.path.join(\n",
    "    os.getenv(\"NOTEBOOK_ROOT\", os.getcwd()),\n",
    "    \"specs\"\n",
    ")\n",
    "%env SPECS_DIR=/workspace/tao-experiments/heartratenet/specs\n",
    "\n",
    "# Showing list of specification files.\n",
    "!ls -rlt $LOCAL_SPECS_DIR\n",
    "!mkdir -p $LOCAL_EXPERIMENT_DIR/model/"
   ]
  },
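  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The setup cell above ships with placeholder values such as `/path/to/local/experiments`. As a minimal sketch, the check below flags variables that still hold a placeholder before any later cell depends on them. The helper `find_placeholder_paths` and the `/path/to/` prefix test are illustrative assumptions, not part of TAO.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "# Assumption: placeholder values in the setup cell start with this prefix.\n",
    "PLACEHOLDER_PREFIX = '/path/to/'\n",
    "\n",
    "\n",
    "def find_placeholder_paths(env, names):\n",
    "    # Return the names whose value is unset, empty, or still a placeholder.\n",
    "    return [\n",
    "        name for name in names\n",
    "        if not env.get(name) or env[name].startswith(PLACEHOLDER_PREFIX)\n",
    "    ]\n",
    "\n",
    "\n",
    "unset = find_placeholder_paths(os.environ, ['LOCAL_PROJECT_DIR', 'PROJECT_DIR'])\n",
    "if unset:\n",
    "    print('Please update these variables before continuing:', ', '.join(unset))\n",
    "else:\n",
    "    print('Project paths look customized.')\n",
    "```\n",
    "\n",
    "Running this right after the setup cell catches a forgotten path early, instead of surfacing later as a missing dataset or spec file."
   ]
  },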
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below maps the project directory on your local host to a workspace directory in the TAO docker instance, so that the data and the results are mapped from in and out of the docker. For more information please refer to the [launcher instance](https://docs.nvidia.com/tao/tao-toolkit/text/tao_launcher.html) in the user guide.\n",
    "\n",
    "When running this cell on AWS, update the drive_map entry with the dictionary defined below, so that you don't have permission issues when writing data into folders created by the TAO docker.\n",
    "\n",
    "```json\n",
    "drive_map = {\n",
    "    \"Mounts\": [\n",
    "            # Mapping the data directory\n",
    "            {\n",
    "                \"source\": os.environ[\"LOCAL_PROJECT_DIR\"],\n",
    "                \"destination\": \"/workspace/tao-experiments\"\n",
    "            },\n",
    "            # Mapping the specs directory.\n",
    "            {\n",
    "                \"source\": os.environ[\"LOCAL_SPECS_DIR\"],\n",
    "                \"destination\": os.environ[\"SPECS_DIR\"]\n",
    "            },\n",
    "        ],\n",
    "    \"DockerOptions\": {\n",
    "        \"user\": \"{}:{}\".format(os.getuid(), os.getgid())\n",
    "    }\n",
    "}\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mapping up the local directories to the TAO docker.\n",
    "import json\n",
    "mounts_file = os.path.expanduser(\"~/.tao_mounts.json\")\n",
    "\n",
    "# Define the dictionary with the mapped drives\n",
    "drive_map = {\n",
    "    \"Mounts\": [\n",
    "        # Mapping the data directory\n",
    "        {\n",
    "            \"source\": os.environ[\"LOCAL_PROJECT_DIR\"],\n",
    "            \"destination\": \"/workspace/tao-experiments\"\n",
    "        },\n",
    "        # Mapping the specs directory.\n",
    "        {\n",
    "            \"source\": os.environ[\"LOCAL_SPECS_DIR\"],\n",
    "            \"destination\": os.environ[\"SPECS_DIR\"]\n",
    "        },\n",
    "    ]\n",
    "}\n",
    "\n",
    "# Writing the mounts file.\n",
    "with open(mounts_file, \"w\") as mfile:\n",
    "    json.dump(drive_map, mfile, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat ~/.tao_mounts.json"
   ]
  },
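  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The notebook keeps two sets of paths: host paths (`$LOCAL_...`) for cells that run locally, and container paths (such as `$USER_EXPERIMENT_DIR`) for `tao` commands. The sketch below illustrates how a host path corresponds to a container path under a list of mounts shaped like the one written above. The helper `to_container_path` and the host paths are made up for illustration; the TAO launcher does not expose such a function.\n",
    "\n",
    "```python\n",
    "import posixpath\n",
    "\n",
    "# Illustrative mounts with the same structure as the `Mounts` list above.\n",
    "# The host-side `source` paths are hypothetical.\n",
    "mounts = [\n",
    "    {'source': '/home/user/experiments',\n",
    "     'destination': '/workspace/tao-experiments'},\n",
    "    {'source': '/home/user/samples/heartratenet/specs',\n",
    "     'destination': '/workspace/tao-experiments/heartratenet/specs'},\n",
    "]\n",
    "\n",
    "\n",
    "def to_container_path(host_path, mounts):\n",
    "    # Try the longest source first so that nested mounts take precedence.\n",
    "    for mount in sorted(mounts, key=lambda m: len(m['source']), reverse=True):\n",
    "        source = mount['source'].rstrip('/')\n",
    "        if host_path == source or host_path.startswith(source + '/'):\n",
    "            relative = host_path[len(source):].lstrip('/')\n",
    "            if not relative:\n",
    "                return mount['destination']\n",
    "            return posixpath.join(mount['destination'], relative)\n",
    "    return None  # the path is not covered by any mount\n",
    "\n",
    "\n",
    "print(to_container_path('/home/user/experiments/heartratenet/data', mounts))\n",
    "```\n",
    "\n",
    "A path that returns `None` here would not be visible to commands running inside the container, which is a common cause of file-not-found errors in later steps."
   ]
  },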
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install requirement\n",
    "!pip3 install Cython==0.29.36\n",
    "!pip3 install -r $PROJECT_DIR/deps/requirements-pip.txt\n",
    "!pip3 install pyyaml"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The heartratenet api uses the `$DATAIO_SPEC` and `$TRAIN_SPEC` yaml files to set up directories. \n",
    "\n",
    "* `$DATAIO_SPEC` is `$LOCAL_SPECS_DIR/heartratenet_data_generation.yaml`\n",
    "* `$TRAIN_SPEC` is `$LOCAL_SPECS_DIR/heartratenet_tlt_pretrain.yaml`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up dataio and train experiment spec path and check if they exist\n",
    "\n",
    "os.environ[\"DATAIO_SPEC\"] = os.path.join(os.environ[\"SPECS_DIR\"], 'heartratenet_data_generation.yaml')\n",
    "os.environ[\"TRAIN_SPEC\"] = os.path.join(os.environ[\"SPECS_DIR\"], 'heartratenet_tlt_pretrain.yaml')\n",
    "os.environ[\"LOCAL_DATAIO_SPEC\"] = os.path.join(os.environ[\"LOCAL_SPECS_DIR\"], 'heartratenet_data_generation.yaml')\n",
    "os.environ[\"LOCAL_TRAIN_SPEC\"] = os.path.join(os.environ[\"LOCAL_SPECS_DIR\"], 'heartratenet_tlt_pretrain.yaml')\n",
    "\n",
    "!if [ ! -f $LOCAL_DATAIO_SPEC ]; then echo \"Dataio spec file not found.\"; else echo \"Found dataio spec file.\";fi\n",
    "!if [ ! -f $LOCAL_TRAIN_SPEC ]; then echo \"Train spec file not found.\"; else echo \"Found train spec file.\";fi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we have to make sure the environment paths selected above match the inputs to the api.\n",
    "Go to `$DATAIO_SPEC` file and change `input_directory_path` and `data_directory_output_path` to the path specified in `$HEARTRATENET_DATA`.\n",
    "Go to `$TRAIN_SPEC` file and change the `results_dir` to `$USER_EXPERIMENT_DIR` and also change the `checkpoint_dir` accordingly. This file is the input to the model. Next, change the `tfrecords_directory_path` to `$HEARTRATENET_DATA`. "
   ]
  },
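  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The edits described above can also be scripted instead of made by hand. The following is a minimal sketch, assuming the specs are plain YAML and that `tfrecords_directory_path` sits under `dataloader` > `dataset_info` in the train spec. The helpers `set_nested` and `update_spec_file` are hypothetical names, and the demonstration runs on an in-memory dictionary rather than on your spec files.\n",
    "\n",
    "```python\n",
    "def set_nested(spec, keys, value):\n",
    "    # Walk down the key path, creating intermediate dicts when needed.\n",
    "    node = spec\n",
    "    for key in keys[:-1]:\n",
    "        node = node.setdefault(key, {})\n",
    "    node[keys[-1]] = value\n",
    "    return spec\n",
    "\n",
    "\n",
    "def update_spec_file(path, updates):\n",
    "    # Requires PyYAML. A load/dump round trip drops comments and may\n",
    "    # reorder keys, so keep a backup of the original spec file.\n",
    "    import yaml\n",
    "    with open(path) as f:\n",
    "        spec = yaml.safe_load(f)\n",
    "    for keys, value in updates:\n",
    "        set_nested(spec, keys, value)\n",
    "    with open(path, 'w') as f:\n",
    "        yaml.safe_dump(spec, f, default_flow_style=False)\n",
    "\n",
    "\n",
    "# In-memory demonstration using keys named in the text above.\n",
    "demo = {'results_dir': '/old', 'dataloader': {'dataset_info': {}}}\n",
    "set_nested(demo, ['results_dir'], '/workspace/tao-experiments/heartratenet')\n",
    "set_nested(demo, ['dataloader', 'dataset_info', 'tfrecords_directory_path'],\n",
    "           '/workspace/tao-experiments/heartratenet/data')\n",
    "print(demo)\n",
    "```\n",
    "\n",
    "To apply this to a real file, you could call `update_spec_file` with the path in `LOCAL_TRAIN_SPEC` and a list of `(keys, value)` pairs. Editing the file by hand remains the safer choice if the spec contains comments you want to keep."
   ]
  },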
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check to see if spec files are found and is updated.\n",
    "try:\n",
    "    from yaml import load, SafeLoader\n",
    "    import os\n",
    "    from os.path import normpath\n",
    "    \n",
    "    with open(os.environ.get(\"LOCAL_DATAIO_SPEC\")) as f:\n",
    "        print('dataio spec found')\n",
    "        dataio_args = load(f, Loader = SafeLoader)\n",
    "    if normpath(dataio_args['input_directory_path'])!= normpath(os.environ.get(\"HEARTRATENET_DATA\")) or normpath(dataio_args['data_directory_output_path']) != normpath(os.environ.get(\"HEARTRATENET_DATA\")):\n",
    "        print(normpath(dataio_args['input_directory_path']), os.environ.get(\"HEARTRATENET_DATA\") )\n",
    "        print('Please update input_directory_path and data_directory_output_path')\n",
    "except:\n",
    "    print('Dataio spec is not found, please ensure there is dataio spec in proper folder')\n",
    "    \n",
    "try:\n",
    "    from yaml import load\n",
    "    import os\n",
    "    from os.path import normpath, join\n",
    "    \n",
    "    with open(os.environ.get(\"LOCAL_TRAIN_SPEC\")) as f:\n",
    "        print('train spec found')\n",
    "        train_args = load(f, Loader=SafeLoader)\n",
    "        \n",
    "    if normpath(train_args['results_dir']) != normpath(os.environ.get(\"USER_EXPERIMENT_DIR\")):\n",
    "        print('Please update results_dir')\n",
    "        \n",
    "    if normpath(train_args['dataloader']['dataset_info']['tfrecords_directory_path'])!= normpath(os.environ.get(\"HEARTRATENET_DATA\")):\n",
    "        print('Please update the tfrecords_directory_path')\n",
    "except:\n",
    "    print('Train spec is not found, please ensure there is train spec in proper folder')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Install the TAO launcher <a class=\"anchor\" id=\"head-1\"></a>\n",
    "The TAO launcher is a python package distributed as a python wheel listed in PyPI. You may install the launcher by executing the following cell.\n",
    "\n",
    "Please note that TAO Toolkit recommends users run the TAO launcher in a virtual env with python 3.6.9. You may follow the instruction on this [page](https://virtualenvwrapper.readthedocs.io/en/latest/install.html) to set up a python virtual env using the `virtualenv` and `virtualenvwrapper` packages. Once you have set up virtualenvwrapper, please set the version of python to be used in the virtual env by using the `VIRTUALENVWRAPPER_PYTHON` variable. You may do so by running\n",
    "\n",
    "```sh\n",
    "export VIRTUALENVWRAPPER_PYTHON=/path/to/bin/python3.x\n",
    "```\n",
    "where x >= 6 and <= 8\n",
    "\n",
    "We recommend performing this step first and then launching the notebook from the virtual environment. In addition to installing TAO python package, please make sure of the following software requirements:\n",
    "* python >=3.7, <=3.10.x\n",
    "* docker-ce > 19.03.5\n",
    "* docker-API 1.40\n",
    "* nvidia-container-toolkit > 1.3.0-1\n",
    "* nvidia-container-runtime > 3.4.0-1\n",
    "* nvidia-docker2 > 2.5.0-1\n",
    "* nvidia-driver > 455+\n",
    "\n",
    "Once you have installed the pre-requisites, please log in to the docker registry nvcr.io by following the command below\n",
    "\n",
    "```sh\n",
    "docker login nvcr.io\n",
    "```\n",
    "\n",
    "You will be triggered to enter a username and password. The username is `$oauthtoken` and the password is the API key generated from `ngc.nvidia.com`. Please follow the instructions in the [NGC setup guide](https://docs.nvidia.com/ngc/ngc-overview/index.html#generating-api-key) to generate your own API key."
   ]
  },
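  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before installing the launcher, it can help to confirm that the running interpreter falls inside the Python range from the requirement list above. This is a minimal sketch; `python_version_supported` is a hypothetical helper, and the bounds are copied from that list rather than queried from the TAO package.\n",
    "\n",
    "```python\n",
    "import sys\n",
    "\n",
    "# Bounds taken from the requirement list above (python >=3.7, <=3.10.x).\n",
    "MIN_VERSION = (3, 7)\n",
    "MAX_VERSION = (3, 10)\n",
    "\n",
    "\n",
    "def python_version_supported(version_info):\n",
    "    # Compare only (major, minor) so that any 3.10.x patch release passes.\n",
    "    return MIN_VERSION <= tuple(version_info[:2]) <= MAX_VERSION\n",
    "\n",
    "\n",
    "status = 'within' if python_version_supported(sys.version_info) else 'outside'\n",
    "print('Python', sys.version.split()[0], 'is', status, 'the listed range')\n",
    "```\n",
    "\n",
    "If the interpreter is outside the range, recreate the virtual env with a supported interpreter before running the install cell."
   ]
  },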
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SKIP this step IF you have already installed the TAO launcher wheel.\n",
    "!pip3 install nvidia-tao"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# View the versions of the TAO launcher\n",
    "!tao info"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Prepare dataset and pre-trained model <a class=\"anchor\" id=\"head-2\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please download COHFACE public dataset from the following website: https://www.idiap.ch/en/dataset/cohface\n",
    "\n",
    "After downloading the data, please extract the data to cohface folder and place it under `$LOCAL_DATA_DIR`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A. Verify downloaded dataset <a class=\"anchor\" id=\"head-2-1\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the dataset is present.\n",
    "!mkdir -p $LOCAL_DATA_DIR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!if [ ! -d $LOCAL_DATA_DIR/cohface ]; then echo \"Data folder not found, please download.\"; else echo \"Data folder found.\";fi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### B. Process the extracted data <a class=\"anchor\" id=\"head-2-3\"></a>\n",
    "\n",
    "The `dataio` module for heartratenet expects the data to be formatted in a predefined format.\n",
    "\n",
    "The `dataio` spec file specifies the folders to be read in the three lists train_subjects, validation_subjects and test_subjects.\n",
    "Place the video file under each subject in a folder named images. The path is `$LOCAL_DATA_DIR/subject_folder`.\n",
    "\n",
    "The ground truth is expected in the following format. For the RGB camera feed, `image_timestamps.csv` consists of frame ID and corresponding timestamp in rows `ID`,`Time`. For the pulse readings, `ground_truth.csv` consists of a timestamp and the corresponding ppg reading in rows `Time`,`PulseWaveform`. The heart rate is predicted as the dominant frequency of the ppg signal. The API takes care of sampling differences between the RGB and PPG signals. COHFACE dataset has 40 subjects, use `start_subject_id` and `end_subject_id` as input arguments to the following script to specify subjects to process, `process_cohface.py` process subjects in range `[start_subject_id, end_subject_id)`\n",
    "\n",
    "The following block will process the COHFACE dataset into a format consistent with heartratenet api."
   ]
  },
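  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the phrase *dominant frequency of the PPG signal* concrete, the sketch below estimates a heart rate from a synthetic pulse waveform with NumPy's FFT. It illustrates the idea only and is not the TAO implementation. The function name, the 0.7 to 4.0 Hz search band, and the 100 Hz sampling rate are assumptions chosen for the example.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def dominant_heart_rate_bpm(signal, sample_rate_hz, min_hz=0.7, max_hz=4.0):\n",
    "    # Remove the mean so the zero-frequency bin does not dominate.\n",
    "    centered = np.asarray(signal, dtype=float) - np.mean(signal)\n",
    "    spectrum = np.abs(np.fft.rfft(centered))\n",
    "    freqs = np.fft.rfftfreq(len(centered), d=1.0 / sample_rate_hz)\n",
    "    # Restrict the search to an assumed plausible band (42 to 240 bpm).\n",
    "    band = (freqs >= min_hz) & (freqs <= max_hz)\n",
    "    peak_hz = freqs[band][np.argmax(spectrum[band])]\n",
    "    return 60.0 * peak_hz\n",
    "\n",
    "\n",
    "# Synthetic waveform: a 1.2 Hz sine (72 beats per minute), 30 s at 100 Hz.\n",
    "sample_rate_hz = 100.0\n",
    "t = np.arange(3000) / sample_rate_hz\n",
    "ppg = np.sin(2 * np.pi * 1.2 * t)\n",
    "estimated_bpm = dominant_heart_rate_bpm(ppg, sample_rate_hz)\n",
    "print(round(estimated_bpm, 1))\n",
    "```\n",
    "\n",
    "On real recordings you would read the `PulseWaveform` column from `ground_truth.csv` and derive the sampling rate from its `Time` column. Longer windows give finer frequency resolution at the cost of slower response to heart rate changes."
   ]
  },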
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python process_cohface.py -i $LOCAL_DATA_DIR/cohface/ \\\n",
    "                           -o $LOCAL_DATA_DIR/cohface_processed \\\n",
    "                           -start_subject_id 1 \\\n",
    "                           -end_subject_id 41"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### C. Download pre-trained model <a class=\"anchor\" id=\"head-2-4\"></a>\n",
    "\n",
    "Please follow the instructions in the following to download and verify the pretrained model for heartratenet.\n",
    "\n",
    "For HeartRateNet pretrained model please download model: `nvidia/tao/heartratenet:trainable_v2.0`.\n",
    "\n",
    "After download the pre-trained model, please place the files in `$LOCAL_EXPERIMENT_DIR/pretrain_models`\n",
    "You will then have the following path\n",
    "\n",
    "* pretrained model in `$LOCAL_EXPERIMENT_DIR/pretrain_models/heartratenet_vtrainable_v2.0/model.tlt`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Installing NGC CLI on the local machine.\n",
    "## Download and install\n",
    "%env CLI=ngccli_cat_linux.zip\n",
    "!mkdir -p $LOCAL_PROJECT_DIR/ngccli\n",
    "\n",
    "# Remove any previously existing CLI installations\n",
    "!rm -rf $LOCAL_PROJECT_DIR/ngccli/*\n",
    "!wget \"https://ngc.nvidia.com/downloads/$CLI\" -P $LOCAL_PROJECT_DIR/ngccli\n",
    "!unzip -u \"$LOCAL_PROJECT_DIR/ngccli/$CLI\" -d $LOCAL_PROJECT_DIR/ngccli/\n",
    "!rm $LOCAL_PROJECT_DIR/ngccli/*.zip \n",
    "os.environ[\"PATH\"]=\"{}/ngccli/ngc-cli:{}\".format(os.getenv(\"LOCAL_PROJECT_DIR\", \"\"), os.getenv(\"PATH\", \"\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List models available in the model registry.\n",
    "!ngc registry model list nvidia/tao/heartratenet:*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the target destination to download the model.\n",
    "!mkdir -p $LOCAL_EXPERIMENT_DIR/pretrain_models/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the pretrained model from NGC\n",
    "!ngc registry model download-version nvidia/tao/heartratenet:trainable_v2.0 \\\n",
    "    --dest $LOCAL_EXPERIMENT_DIR/pretrain_models/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls -rlt $LOCAL_EXPERIMENT_DIR/pretrain_models/heartratenet_vtrainable_v2.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the dataset is present\n",
    "!if [ ! -f $LOCAL_EXPERIMENT_DIR/pretrain_models/heartratenet_vtrainable_v2.0/model.tlt ]; then echo 'Pretrain model file not found, please download.'; else echo 'Found Pretrain model file.';fi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Generate tfrecords from RGB videos <a class=\"anchor\" id=\"head-3\"></a>\n",
    "* Download haarcascade classifier and prepare directory\n",
    "* Generate required motion and appearance maps for attention network.\n",
    "* Create the tfrecords using the tao command"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A. Download haarcascade classifier <a class=\"anchor\" id=\"head-3-1\"></a>\n",
    "Obtain the haarcascade classifer from (https://github.com/opencv/opencv/blob/master/data/haarcascades/haarcascade_frontalface_default.xml)\n",
    "\n",
    "After downloading the haarcascade classifier, please place `haarcascade_frontalface_default.xml` in `$LOCAL_DATA_DIR`\n",
    "\n",
    "You will have the following path\n",
    "\n",
    "* haarcascade file in\n",
    "`$LOCAL_DATA_DIR/haarcascade_frontalface_default.xml`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://github.com/opencv/opencv/raw/master/data/haarcascades/haarcascade_frontalface_default.xml -P $LOCAL_DATA_DIR/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: Please make sure the file has been downloaded successfully, failure to do so will result in the rest of the notebook not being operational."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check to see if haar classifier is present.\n",
    "!if [ ! -f $LOCAL_DATA_DIR/haarcascade_frontalface_default.xml ]; then echo \"Classifier not found, please ensure classifier is in proper file\"; else echo \"Classifier found.\";fi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### B. Generate tfrecords <a class=\"anchor\" id=\"head-3-2\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!tao model heartratenet dataset_convert --experiment_spec_file $DATAIO_SPEC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check to see if tfrecords are present.\n",
    "!if [ ! -f $LOCAL_DATA_DIR/train.tfrecord ]; then echo \"Did not find training file, please ensure training record is generated.\"; else echo \"Found training record\";fi\n",
    "!if [ ! -f $LOCAL_DATA_DIR/validation.tfrecord ]; then echo \"Did not find validation file, please ensure validation record is generated.\"; else echo \"Found validation record\";fi\n",
    "!if [ ! -f $LOCAL_DATA_DIR/test.tfrecord ]; then echo \"Did not find test file, please ensure test record is generated.\"; else echo \"Found test record\";fi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Provide training specification <a class=\"anchor\" id=\"head-4\"></a>\n",
    "* Tfrecords for the training dataset\n",
    "    * In order to use the newly generated tfrecords for training, update the `tfrecords_directory_path` parameter of `dataset_info` section in the spec file at `$TRAIN_SPEC`\n",
    "* Pre-trained model path\n",
    "    * Update `checkpoint_dir` in the spec file `$TRAIN_SPEC`\n",
    "* Other training (hyper-)parameters such as batch size, number of epochs, learning rate, etc.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat $LOCAL_TRAIN_SPEC"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Run TAO training <a class=\"anchor\" id=\"head-5\"></a>\n",
    "* Provide the sample spec file and the output directory location for models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!tao model heartratenet train -e $TRAIN_SPEC \\\n",
    "                        -k $KEY \\\n",
    "                        -r $USER_EXPERIMENT_DIR/model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Evaluate the trained model <a class=\"anchor\" id=\"head-6\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!tao model heartratenet evaluate -e $TRAIN_SPEC \\\n",
    "                           -k $KEY \\\n",
    "                           -m $USER_EXPERIMENT_DIR/model/ \\\n",
    "                           -r $USER_EXPERIMENT_DIR/eval_results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Inference <a class=\"anchor\" id=\"head-7\"></a>\n",
    "* Ensure you have the required data format as indicated in the model card\n",
    "* Modify `m` to the full model path for evaluation\n",
    "* Modify `subject_infer_dir` and `subject` below to align with your data\n",
    "* Modify `results_dir` to your desired result directory\n",
    "* Modify `fps` to match inference data fps, COHFACE dataset recorded in 20fps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!tao model heartratenet inference -m $USER_EXPERIMENT_DIR/model/model.tlt \\\n",
    "                            --subject_infer_dir $HEARTRATENET_DATA \\\n",
    "                            --subject cohface_processed/1/0 \\\n",
    "                            --results_dir $USER_EXPERIMENT_DIR \\\n",
    "                            --fps 20 \\\n",
    "                            -k $KEY \\\n",
    "                            -c channels_first"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import cv2\n",
    "import IPython.display\n",
    "import PIL.Image\n",
    "\n",
    "subject_infer_dir = os.environ['LOCAL_DATA_DIR']\n",
    "subject = 'Subject1'\n",
    "display_freq = 30\n",
    "results_file = os.path.join(os.environ['LOCAL_EXPERIMENT_DIR'], 'results.txt')\n",
    "with open(results_file, 'r') as file:\n",
    "    results = file.read()\n",
    "\n",
    "subject_work_dir = os.path.join(subject_infer_dir, subject, 'images')\n",
    "cap = cv2.VideoCapture(os.path.join(subject_work_dir, '%04d.bmp'))\n",
    "while cap.isOpened():\n",
    "    ret, frame = cap.read()\n",
    "    if ret:\n",
    "        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
    "        IPython.display.display(PIL.Image.fromarray(frame))\n",
    "    else:\n",
    "        break\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Export <a class=\"anchor\" id=\"head-8\"></a>\n",
    "* Modify `m` to your model directory path\n",
    "* Modify `out_file` to your desired full output path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the model for conversion is present.\n",
    "!if [ ! -f $LOCAL_EXPERIMENT_DIR/model/model.tlt ]; then echo \"Model file not found, please download.\"; else echo \"Model file found.\";fi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!tao model heartratenet export -m $USER_EXPERIMENT_DIR/model/model.tlt \\\n",
    "                         -k $KEY \\\n",
    "                         --out_file $USER_EXPERIMENT_DIR/exported_model.tlt"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
